/*
 * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package vfile.face;

import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Landmark;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.translate.TranslateException;
import util.ConverterImg;
import util.ImageUI;
import util.WorkId;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.global.opencv_imgcodecs;
import org.bytedeco.opencv.global.opencv_imgproc;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.Point2f;
import org.bytedeco.opencv.opencv_core.Rect;
import org.bytedeco.opencv.opencv_core.Size;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line demo that detects the faces in two images, aligns and crops each detected face,
 * extracts a feature vector per face and logs a similarity score for every pair of faces
 * (one face from the first image, one from the second).
 */
public final class FeatureComparison {

    private static final Logger logger = LoggerFactory.getLogger(FeatureComparison.class);

    /** Directory the aligned face crops are written to. */
    private static final String OUTPUT_DIR = "build/output";

    /** Edge length, in pixels, of the square aligned face images. */
    private static final int FACE_SIZE = 224;

    /*
     * Landmark indices as used by the original implementation: the first two landmarks are the
     * eyes and the third one is the nose.
     */
    private static final int FIRST_EYE = 0;
    private static final int SECOND_EYE = 1;
    private static final int NOSE = 2;

    private FeatureComparison() {}

    /**
     * Runs the comparison on the two bundled sample images and logs one similarity value per
     * pair of detected faces.
     *
     * @param args ignored
     * @throws Exception if an image cannot be read or detection / feature extraction fails
     */
    public static void main(String[] args) throws Exception {
        Path imageFile1 = Paths.get("src/main/resources/img/kana1.jpg");
        Path imageFile2 = Paths.get("src/main/resources/img/kana2.jpg");

        Mat face = readImage(imageFile1);
        Image img1 = ImageFactory.getInstance().fromImage(ConverterImg.m2B(face));
        Mat face2 = readImage(imageFile2);
        Image img2 = ImageFactory.getInstance().fromImage(ConverterImg.m2B(face2));

        // Detect the faces, then crop and align each of them.
        DetectedObjects box1 = RetinaFaceOpencvDetection.predict(img1);
        DetectedObjects box2 = RetinaFaceOpencvDetection.predict(img2);

        List<DetectedObjects.DetectedObject> list1 = box1.items();
        List<Mat> faces1 = faces(face, list1, img1.getWidth(), img1.getHeight());

        List<DetectedObjects.DetectedObject> list2 = box2.items();
        List<Mat> faces2 = faces(face2, list2, img2.getWidth(), img2.getHeight());

        FeatureExtraction.init();

        // Embed every face of the second image once instead of once per face of the first image.
        List<float[]> features2 = new ArrayList<>(faces2.size());
        if (!faces1.isEmpty()) {
            for (Mat fs : faces2) {
                features2.add(
                        FeatureExtraction.predict(
                                ImageFactory.getInstance().fromImage(ConverterImg.m2B(fs))));
            }
        }

        for (Mat ff : faces1) {
            float[] feature1 =
                    FeatureExtraction.predict(
                            ImageFactory.getInstance().fromImage(ConverterImg.m2B(ff)));
            for (float[] feature2 : features2) {
                logger.info(Float.toString(calculSimilar(feature1, feature2)));
            }
        }
    }

    /**
     * Crops, rotates and resizes every detected face that carries landmarks.
     *
     * <p>Each aligned face is resized to {@value #FACE_SIZE}x{@value #FACE_SIZE}, appended to the
     * returned list and also written as a JPEG file into {@value #OUTPUT_DIR}. Detections are
     * skipped when their bounding box is not a {@link Landmark}, when they have fewer than three
     * landmarks, or when the scaled rectangle is empty or not fully inside {@code src}.
     *
     * @param src the image the detections refer to
     * @param list the detections to process
     * @param imageWidth1 width used to scale the bounding-box rectangle to pixels
     * @param imageHeight1 height used to scale the bounding-box rectangle to pixels
     * @return the aligned faces, in detection order; the caller owns the returned {@code Mat}s
     * @throws IOException if the output directory cannot be created
     */
    public static List<Mat> faces(Mat src, List<DetectedObjects.DetectedObject> list, int imageWidth1, int imageHeight1) throws IOException {
        Path outputDir = Paths.get(OUTPUT_DIR);
        Files.createDirectories(outputDir);

        List<Mat> matFaces = new ArrayList<>();
        for (DetectedObjects.DetectedObject result : list) {
            BoundingBox box = result.getBoundingBox();
            if (!(box instanceof Landmark)) {
                continue;
            }

            List<Point> points = new ArrayList<>();
            for (Point landmark : box.getPath()) {
                points.add(landmark);
            }
            if (points.size() <= NOSE) {
                // Both eyes and the nose are required for the alignment below.
                continue;
            }

            Rectangle rectangle = box.getBounds();
            int x = (int) (rectangle.getX() * imageWidth1);
            int y = (int) (rectangle.getY() * imageHeight1);
            int width = (int) (rectangle.getWidth() * imageWidth1);
            int height = (int) (rectangle.getHeight() * imageHeight1);

            // Rect wraps native memory, so release it as soon as the crop has been taken.
            try (Rect rect = new Rect(x, y, width, height)) {
                if (!isInside(rect, src)) {
                    continue;
                }

                Mat aligned = alignFace(src, rect, points);
                matFaces.add(aligned);

                // Keep a copy of the aligned face on disk for inspection.
                String target = outputDir.resolve(WorkId.sortUID() + ".jpg").toString();
                if (!opencv_imgcodecs.imwrite(target, aligned)) {
                    logger.warn("Failed to write aligned face to {}", target);
                }
            }
        }
        return matFaces;
    }

    /**
     * Computes the cosine similarity of two feature vectors, mapped from {@code [-1, 1]} to
     * {@code [0, 1]}.
     *
     * @param feature1 the first feature vector
     * @param feature2 the second feature vector, of the same length as {@code feature1}
     * @return {@code (cosine + 1) / 2}; {@code NaN} if either vector has zero magnitude (which
     *     includes empty vectors), because the cosine is undefined in that case
     * @throws IllegalArgumentException if the two vectors differ in length
     */
    public static float calculSimilar(float[] feature1, float[] feature2) {
        if (feature1.length != feature2.length) {
            throw new IllegalArgumentException(
                    "Feature vectors differ in length: "
                            + feature1.length
                            + " vs "
                            + feature2.length);
        }

        // Accumulate in double so long vectors do not lose precision.
        double dot = 0.0;
        double squaredNorm1 = 0.0;
        double squaredNorm2 = 0.0;
        for (int i = 0; i < feature1.length; ++i) {
            dot += (double) feature1[i] * feature2[i];
            squaredNorm1 += (double) feature1[i] * feature1[i];
            squaredNorm2 += (double) feature2[i] * feature2[i];
        }

        double cosine = dot / Math.sqrt(squaredNorm1) / Math.sqrt(squaredNorm2);
        return (float) ((cosine + 1.0) / 2.0);
    }

    /** Loads an image from disk, failing loudly instead of returning an empty matrix. */
    private static Mat readImage(Path file) throws IOException {
        Mat image = opencv_imgcodecs.imread(file.toString());
        if (image == null || image.empty()) {
            throw new IOException("Unable to read image: " + file);
        }
        return image;
    }

    /** Tells whether {@code rect} is non-empty and lies completely inside {@code image}. */
    private static boolean isInside(Rect rect, Mat image) {
        return rect.width() > 0
                && rect.height() > 0
                && rect.x() >= 0
                && rect.y() >= 0
                && rect.x() + rect.width() <= image.cols()
                && rect.y() + rect.height() <= image.rows();
    }

    /** Crops {@code rect} out of {@code src}, rotates the crop by the eye angle and resizes it. */
    private static Mat alignFace(Mat src, Rect rect, List<Point> points) {
        Point firstEye = points.get(FIRST_EYE);
        Point secondEye = points.get(SECOND_EYE);
        Point nose = points.get(NOSE);

        // Angle, in degrees, of the line through the two eyes.
        double dy = secondEye.getY() - firstEye.getY();
        double dx = secondEye.getX() - firstEye.getX();
        double angle = Math.toDegrees(Math.atan2(dy, dx));

        Mat aligned = new Mat();
        // The rotation is centred on the nose (the midpoint between the eyes is an alternative).
        // NOTE(review): the landmark coordinates are passed unchanged as the rotation centre of
        // the *cropped* face. If they are expressed in full-image coordinates, the centre is
        // offset by the crop origin - confirm against RetinaFaceOpencvDetection.
        // Every temporary below wraps native memory and is released when the block exits.
        try (Point2f center = new Point2f((float) nose.getX(), (float) nose.getY());
                // Scale 1.0 means the rotation does not resize the image.
                Mat rotation = opencv_imgproc.getRotationMatrix2D(center, angle, 1.0);
                Mat crop = new Mat(src, rect);
                Size cropSize = crop.size();
                Size outputSize = new Size(FACE_SIZE, FACE_SIZE)) {
            // Known limitation noted by the original author: some image information is lost
            // by the affine transform, whose output keeps the size of the crop.
            opencv_imgproc.warpAffine(crop, aligned, rotation, cropSize);
            opencv_imgproc.resize(aligned, aligned, outputSize);
        }
        return aligned;
    }
}
